# Copyright (c) Alibaba, Inc. and its affiliates.
# Code partially sourced from Hugging Face TRL

import asyncio
import inspect
import multiprocessing
import os
import time
import traceback
from collections.abc import Sequence
from contextlib import asynccontextmanager, contextmanager
from dataclasses import asdict
from functools import wraps
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Dict, List, Optional, Union

import torch
import uvicorn
from aiohttp import ClientConnectorError
from fastapi import FastAPI
from trl.scripts.vllm_serve import WeightSyncWorkerExtension as HFWeightSyncWorkerExtension

from swift.llm import RolloutArguments, SwiftPipeline
from swift.llm.template.template_inputs import RolloutInferRequest
from swift.plugin.multi_turn import RolloutScheduler, multi_turns
from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, FlattenedTensorMetadata, TensorLoRARequest,
                                               patch_vllm_load_adapter)
from swift.utils import get_logger
from .infer_engine import GRPOVllmEngine, InferClient
from .protocol import (InitCommunicatorRequest, RequestConfig, UpdateFlattenedAdapterRequest,
                       UpdateFlattenedParamsRequest, UpdateWeightsRequest)

try:
    from vllm.utils import get_open_port
    from trl.scripts.vllm_serve import chunk_list

except ImportError:
    pass
"""
This module defines the execution logic for `swift rollout`.
It adds weight synchronization logic based on `vLLMEngine`.

Usage:
    swift rollout \
        --model xxx \
        --vllm_tensor_parallel_size xxx \
        --vllm_data_parallel_size xxx \
        --vllm_use_async_engine true/false \
        --other_vllm_arguments

Note:
- Rollout is intended solely for GRPO training sampling.
- For inference or deployment, please use the `swift infer` or `swift deploy` commands.
"""

# Patch vLLM's adapter loading at import time so LoRA weights can be supplied
# in memory via TensorLoRARequest (see `update_adapter_flattened_param`, which
# passes `lora_tensors` with a placeholder 'dummy_lora_path').
patch_vllm_load_adapter()


class WeightSyncWorkerExtension(HFWeightSyncWorkerExtension):
    """vLLM worker extension that receives trainer-broadcast weights over NCCL."""

    def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
        """
        Receive one updated weight tensor from the client process and load it into the model.

        Args:
            name (`str`):
                Name of the weight tensor being updated.
            dtype (`str`):
                Data type of the weight tensor as a string (e.g., `"torch.float32"`).
            shape (`Sequence[int]`):
                Shape of the weight tensor.
        """
        if self._comm is None:
            raise RuntimeError('Communicator not initialized. Call `init_communicator` first.')

        torch_dtype = getattr(torch, dtype.split('.')[-1])
        # Pre-allocate a receive buffer on this worker's device.
        recv_buffer = torch.empty(shape, dtype=torch_dtype, device=self.device)

        # NCCL broadcast: the client rank is the source, every worker receives.
        self._comm.broadcast(recv_buffer, src=self.client_rank)
        self._comm.group.barrier()

        # Hand the received tensor to the model's own weight loader.
        self.model_runner.model.load_weights(weights=[(name, recv_buffer)])

    def update_adapter_flattened_param(self, lora_int_id: int, peft_config: Dict, metadatas: list[Dict]) -> None:
        """
        Receive a flattened LoRA adapter over NCCL and register it with the engine.
        """
        parsed_metadatas = [FlattenedTensorMetadata(**md) for md in metadatas]
        if self._comm is None:
            raise RuntimeError('Communicator not initialized. Call `init_communicator` first.')
        # The last metadata entry's end index gives the total flattened length.
        total_length = parsed_metadatas[-1].end_idx
        torch_dtype = getattr(torch, parsed_metadatas[-1].dtype.split('.')[-1])
        recv_buffer = torch.empty(total_length, dtype=torch_dtype, device=self.device)
        self._comm.broadcast(recv_buffer, src=self.client_rank)
        self._comm.group.barrier()
        bucket = FlattenedTensorBucket(metadata=parsed_metadatas, flattened_tensor=recv_buffer)
        adapter_request = TensorLoRARequest(
            lora_name=f'{lora_int_id}',
            lora_int_id=lora_int_id,
            lora_path='dummy_lora_path',
            peft_config=peft_config,
            lora_tensors=bucket.reconstruct_tensors())
        self.add_lora(adapter_request)

    def update_flattened_params(self, metadatas: list[Dict]) -> None:
        """
        Receive updated flattened weights from the client process and apply them to the model.

        Args:
            metadatas (list[Dict]): List of metadata dictionaries for the flattened tensors.
        """
        parsed_metadatas = [FlattenedTensorMetadata(**md) for md in metadatas]
        if self._comm is None:
            raise RuntimeError('Communicator not initialized. Call `init_communicator` first.')

        # The last metadata entry's end index gives the total flattened length.
        total_length = parsed_metadatas[-1].end_idx
        torch_dtype = getattr(torch, parsed_metadatas[-1].dtype.split('.')[-1])
        recv_buffer = torch.empty(total_length, dtype=torch_dtype, device=self.device)

        self._comm.broadcast(recv_buffer, src=self.client_rank)
        self._comm.group.barrier()

        bucket = FlattenedTensorBucket(metadata=parsed_metadatas, flattened_tensor=recv_buffer)
        reconstructed = bucket.reconstruct_tensors()

        # Load the reconstructed parameters into the model.
        self.model_runner.model.load_weights(weights=list(reconstructed.items()))

    @property
    def _comm(self):
        """
        Communicator handle, resolved across TRL versions.

        - trl >= 0.24.0 exposes ``communicator``
        - trl < 0.24.0 exposes ``pynccl_comm``
        Returns None when neither attribute exists.
        """
        for attr_name in ('communicator', 'pynccl_comm'):
            if hasattr(self, attr_name):
                return getattr(self, attr_name)
        return None


# Module-level logger shared by the worker loops and the deployment server.
logger = get_logger()


def safe_set_start_method():
    """Select the 'spawn' start method, but only when none has been chosen yet."""
    current_method = multiprocessing.get_start_method(allow_none=True)
    if current_method is None:
        multiprocessing.set_start_method('spawn')


def get_rollout_engine_type(args: RolloutArguments, engine: GRPOVllmEngine):
    """
    Wrap `engine` in a multi-turn scheduler when one is configured.

    Args:
        args: Rollout arguments. `multi_turn_scheduler` names a scheduler
            registered in `multi_turns`; `max_turns` bounds the dialogue length.
        engine: The underlying vLLM inference engine.

    Returns:
        A `RolloutScheduler` wrapping `engine` when a multi-turn scheduler is
        configured, otherwise `engine` itself.

    Raises:
        ValueError: If the named scheduler is not registered, or initialization
            yields a falsy object.
    """
    if not args.multi_turn_scheduler:
        return engine

    if args.multi_turn_scheduler not in multi_turns:
        raise ValueError(f"Multi-turn scheduler '{args.multi_turn_scheduler}' not found in multi_turns.")
    scheduler_cls = multi_turns[args.multi_turn_scheduler]

    kwargs = {}
    # Only pass the tokenizer when the scheduler's __init__ accepts it.
    # (Membership test works directly on the signature's parameter mapping.)
    if 'tokenizer' in inspect.signature(scheduler_cls.__init__).parameters:
        kwargs['tokenizer'] = engine.default_template.tokenizer
    # Extra kwargs for gym-style environments.
    if args.use_gym_env:
        kwargs.update({
            'use_gym_env': args.use_gym_env,
            'gym_env': args.gym_env,
            'context_manager': args.context_manager,
        })

    rollout_engine: RolloutScheduler = scheduler_cls(infer_engine=engine, max_turns=args.max_turns, **kwargs)
    if not rollout_engine:
        raise ValueError(f"Failed to initialize multi-turn scheduler '{args.multi_turn_scheduler}'.")
    return rollout_engine


def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int, connection: Connection) -> None:
    """
    Entry point of one synchronous data-parallel worker process.

    Builds a rollout engine for this DP rank, signals readiness, then serves
    RPC-style commands received over `connection` until a 'shutdown' command
    (or KeyboardInterrupt) arrives.

    Args:
        args: Rollout arguments used to construct the engine.
        data_parallel_rank: Rank of this worker within the DP group.
        master_port: Port of the vLLM DP master.
        connection: Pipe endpoint for exchanging commands/results with the parent.
    """
    args._import_external_plugins()
    args._init_custom_register()
    # Set required environment variables for DP to work with vLLM.
    os.environ['VLLM_DP_RANK'] = str(data_parallel_rank)
    os.environ['VLLM_DP_RANK_LOCAL'] = str(data_parallel_rank)
    os.environ['VLLM_DP_SIZE'] = str(args.vllm_data_parallel_size)
    os.environ['VLLM_DP_MASTER_PORT'] = str(master_port)
    engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None))
    rollout_engine = get_rollout_engine_type(args, engine)
    # Send ready signal to parent process.
    connection.send({'status': 'ready'})

    while True:
        # Wait for commands from the parent process.
        try:
            command = connection.recv()
        except KeyboardInterrupt:
            engine.engine.collective_rpc(method='close_communicator')
            break

        # Handle commands. 'call' replies with the result; 'fire_and_forget' does not.
        if command['type'] in ['call', 'fire_and_forget']:
            method_name = command['method']
            # Dedicated names so the `args` parameter is not shadowed.
            call_args, call_kwargs = command.get('args', ()), command.get('kwargs', {})
            # Look up the method on the scheduler first, then on the raw engine.
            method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None)
            try:
                result = method(*call_args, **call_kwargs)
            except Exception:
                logger.error(f'Method execution failed: {method_name}\n{traceback.format_exc()}')
                result = None
            if command['type'] == 'call':
                connection.send(result)
        elif command['type'] == 'shutdown':
            break


async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int,
                           connection: Connection) -> None:
    """
    Entry point of the asynchronous engine worker.

    Unlike `llm_worker`, no VLLM_DP_* environment variables are set here:
    with the async engine, data parallelism is configured through the engine's
    own `data_parallel_size` (see `SwiftRolloutDeploy.get_infer_engine`).

    Args:
        args: Rollout arguments used to construct the engine.
        data_parallel_rank: Rank of this worker (a single worker is used for
            the async engine).
        master_port: Port of the vLLM DP master (unused by the async path).
        connection: Pipe endpoint for exchanging commands/results with the parent.
    """
    args._import_external_plugins()
    args._init_custom_register()
    engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None))

    rollout_engine = get_rollout_engine_type(args, engine)

    # Send ready signal to parent process.
    connection.send({'status': 'ready'})

    loop = asyncio.get_running_loop()
    while True:
        # recv() blocks, so run it in a thread to keep the event loop responsive.
        try:
            command = await loop.run_in_executor(None, connection.recv)
        except KeyboardInterrupt:
            await engine.engine.collective_rpc(method='close_communicator')
            break

        # Handle commands. 'call' replies with the result; 'fire_and_forget' does not.
        if command['type'] in ['call', 'fire_and_forget']:
            method_name = command['method']
            # Dedicated names so the `args` parameter is not shadowed.
            call_args, call_kwargs = command.get('args', ()), command.get('kwargs', {})
            # Look up the method on the scheduler first, then on the raw engine.
            method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None)
            try:
                result = await method(*call_args, **call_kwargs)
            except Exception:
                logger.error(f'Method execution failed: {method_name}\n{traceback.format_exc()}')
                result = None

            if command['type'] == 'call':
                connection.send(result)
        elif command['type'] == 'shutdown':
            break


def llm_worker_entry(*args, **kwargs):
    """Process entry point: run the async DP worker inside a fresh event loop."""
    worker_coro = async_llm_worker(*args, **kwargs)
    asyncio.run(worker_coro)


class SwiftRolloutDeploy(SwiftPipeline):
    """FastAPI server exposing GRPO rollout endpoints backed by DP vLLM worker processes."""
    args_class = RolloutArguments
    args: args_class

    def _register_rl_rollout_app(self):
        # Register the RL rollout HTTP routes on the FastAPI app.
        self.app.get('/health/')(self.health)
        self.app.get('/get_world_size/')(self.get_world_size)
        self.app.post('/init_communicator/')(self.init_communicator)
        self.app.post('/update_named_param/')(self.update_named_param)
        self.app.post('/update_adapter_flattened_param/')(self.update_adapter_flattened_param)
        self.app.post('/update_flattened_params/')(self.update_flattened_params)
        self.app.post('/reset_prefix_cache/')(self.reset_prefix_cache)
        self.app.post('/close_communicator/')(self.close_communicator)
        self.app.post('/infer/', response_model=None)(self.infer)
        self.app.post('/get_engine_type/')(self.get_engine_type)

    def __init__(self, args: Optional[Union[List[str], RolloutArguments]] = None):
        super().__init__(args)
        self.use_gym_env = self.args.use_gym_env
        self.use_async_engine = self.args.vllm_use_async_engine
        # The async engine handles DP internally, so it needs only one worker process.
        self.num_connections = 1 if self.use_async_engine else self.args.vllm_data_parallel_size
        safe_set_start_method()
        self.app = FastAPI(lifespan=self.lifespan)
        self._register_rl_rollout_app()
        self.master_port = get_open_port()
        self.connections = []
        self.processes = []
        self._start_data_parallel_workers()

    def _start_data_parallel_workers(self):
        # Spawn one worker process per connection; each gets its own pipe endpoint.
        for data_parallel_rank in range(self.num_connections):
            parent_conn, child_conn = Pipe()
            worker_func = llm_worker_entry if self.use_async_engine else llm_worker
            process = Process(target=worker_func, args=(self.args, data_parallel_rank, self.master_port, child_conn))
            process.start()
            self.connections.append(parent_conn)
            self.processes.append(process)

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        """FastAPI lifespan: block startup until every worker reports ready; join workers on shutdown."""
        # Wait for all workers to send "ready"
        ready_connections = set()

        while len(ready_connections) < self.num_connections:
            for connection in self.connections:
                msg = connection.recv()
                if isinstance(msg, dict) and msg.get('status') == 'ready':
                    ready_connections.add(connection)

        yield

        # Wait for processes to terminate
        for process in self.processes:
            process.join(timeout=10)  # Wait for 10 seconds for the process to terminate
            if process.is_alive():
                logger.warning(f'Process {process} is still alive after 10 seconds, attempting to terminate...')
                process.terminate()
                process.join()  # ensure process termination after calling terminate()

    @staticmethod
    def get_infer_engine(args: RolloutArguments, template=None, **kwargs):
        """
        Build the GRPOVllmEngine used for rollout sampling.

        Forces the vLLM backend, registers the weight-sync worker extension and
        uses dummy load format (weights arrive later via weight sync).
        """
        kwargs.update({
            'model_id_or_path': args.model,
            'model_type': args.model_type,
            'revision': args.model_revision,
            'torch_dtype': args.torch_dtype,
            'template': template,
            'use_async_engine': args.vllm_use_async_engine,
            'max_lora_rank': args.vllm_max_lora_rank,
        })
        infer_backend = kwargs.pop('infer_backend', None) or args.infer_backend
        if infer_backend != 'vllm':
            infer_backend = 'vllm'
            logger.info('Currently, rollout only supports the vLLM backend. Set vLLM backend')
        kwargs.update(args.get_vllm_engine_kwargs())
        kwargs.update({'enable_lora': args.vllm_enable_lora})  # override
        # used for RL external rollout backend
        engine_kwargs = kwargs.get('engine_kwargs', {})
        # for RL rollout model weight sync
        engine_kwargs.update({'worker_extension_cls': 'swift.llm.infer.rollout.WeightSyncWorkerExtension'})
        engine_kwargs['load_format'] = 'dummy'
        if args.vllm_use_async_engine and args.vllm_data_parallel_size > 1:
            engine_kwargs['data_parallel_size'] = args.vllm_data_parallel_size
        kwargs['engine_kwargs'] = engine_kwargs

        return GRPOVllmEngine(**kwargs)

    async def health(self):
        """
        Health check endpoint to verify that the server is running.
        """
        return {'status': 'ok'}

    async def get_world_size(self):
        """
        Retrieves the world size from the LLM engine.

        Returns:
            `dict`:
                A dictionary containing the world size.

        Example response:
        ```json
        {"world_size": 8}
        ```
        """
        return {'world_size': self.args.vllm_tensor_parallel_size * self.args.vllm_data_parallel_size}

    async def init_communicator(self, request: InitCommunicatorRequest):
        """
        Initializes the communicator for synchronizing model weights between a client and multiple server
        workers.

        Args:
            request (`InitCommunicatorRequest`):
                - `host` (`str`): Hostname or IP address of the master node.
                - `port` (`int`): Port number to be used for communication.
                - `world_size` (`int`): Total number of participating processes in the group.
        """
        # +1 accounts for the training client joining the group.
        world_size = self.args.vllm_tensor_parallel_size * self.args.vllm_data_parallel_size + 1

        # The function init_communicator is called this way: init_communicator(host, port, world_size)
        # So with collective_rpc we need to call it this way:
        # llm.collective_rpc(method="init_communicator", args=(host, port, world_size))
        kwargs = {
            'method':
            'init_communicator',
            'args': (request.host, request.port, world_size, *(() if request.client_device_uuid is None else
                                                               (request.client_device_uuid, )))
        }
        for connection in self.connections:
            connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})

        return {'message': 'Request received, initializing communicator'}

    async def update_named_param(self, request: UpdateWeightsRequest):
        """
        Updates the model weights with the provided tensor.

        Once this endpoint is called, the client process should broadcast the updated weights to all server workers.

        Args:
            request (`UpdateWeightsRequest`):
                - `name` (`str`): Name of the weight tensor being updated.
                - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`).
                - `shape` (list of `int`): Shape of the weight

        """
        # The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10))
        # So with collective_rpc we need to call it this way:
        # llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10)))
        kwargs = {'method': 'update_named_param', 'args': (request.name, request.dtype, tuple(request.shape))}
        for connection in self.connections:
            connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})

        return {'message': 'Request received, updating named parameter'}

    async def update_adapter_flattened_param(self, request: UpdateFlattenedAdapterRequest):
        """
        Forwards a flattened LoRA adapter update to every worker.

        The client is expected to broadcast the flattened tensor after this call.
        """
        peft_config = asdict(request.peft_config)
        # model_dump is pydantic v2; fall back to .dict() for pydantic v1.
        metadatas = [
            metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict()
            for metadata in request.metadatas
        ]
        kwargs = {'method': 'update_adapter_flattened_param', 'args': (request.lora_int_id, peft_config, metadatas)}
        for connection in self.connections:
            connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})

        return {'message': 'Request received, updating adapter parameter'}

    async def update_flattened_params(self, request: UpdateFlattenedParamsRequest):
        """
        Updates the model weights with flattened tensor data.

        Args:
            request (UpdateFlattenedParamsRequest):
                - metadatas (List[FlattenedTensorMetadata]): Metadata for the flattened tensors.

        """
        # model_dump is pydantic v2; fall back to .dict() for pydantic v1.
        metadatas = [
            metadata.model_dump() if hasattr(metadata, 'model_dump') else metadata.dict()
            for metadata in request.metadatas
        ]
        kwargs = {'method': 'update_flattened_params', 'args': (metadatas, )}
        for connection in self.connections:
            connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})

        return {'message': 'Request received, updating flattened parameters'}

    async def reset_prefix_cache(self):
        """
        Resets the prefix cache for the model.
        """
        for connection in self.connections:
            connection.send({'type': 'call', 'method': 'reset_prefix_cache'})
        # Wait for and collect all results
        all_outputs = [connection.recv() for connection in self.connections]
        success = all(all_outputs)
        return {'message': 'Request received, resetting prefix cache status: ' + str(success)}

    async def get_engine_type(self):
        """
        Return a dictionary describing the runtime engine configuration.

        The returned object contains these keys:
        - engine_type (str): Either 'AsyncLLMEngine' or 'LLMEngine', indicating
        whether the asynchronous or synchronous engine is in use.
        - use_gym_env (bool): True only when ``use_async_engine`` and
        ``use_gym_env`` are both True.
        - enable_multi_turn (bool): True if multi-turn scheduling is enabled
        via ``args.multi_turn_scheduler``, otherwise False.
        - enable_lora (bool): Whether LoRA is enabled on the vLLM engine.

        Returns
        -------
        dict
            A concise specification of the current engine setup.
        """
        enable_multi_turn = bool(self.args.multi_turn_scheduler)
        use_gym_env = self.use_async_engine and self.use_gym_env
        engine_type = 'AsyncLLMEngine' if self.use_async_engine else 'LLMEngine'
        enable_lora = self.args.vllm_enable_lora
        return {
            'engine_type': engine_type,
            'enable_multi_turn': enable_multi_turn,
            'use_gym_env': use_gym_env,
            'enable_lora': enable_lora,
        }

    async def close_communicator(self):
        """
        Closes the weight update group and cleans up associated resources.
        """
        kwargs = {'method': 'close_communicator'}
        for connection in self.connections:
            connection.send({'type': 'fire_and_forget', 'method': 'collective_rpc', 'kwargs': kwargs})
        return {'message': 'Request received, closing communicator'}

    async def infer(
        self,
        infer_requests: List[Union[Dict, RolloutInferRequest]],
        request_config: Optional[RequestConfig] = None,
        *,
        use_tqdm: Optional[bool] = None,
    ):
        """
        Fan the inference requests out over the DP workers and gather the results.

        Args:
            infer_requests: The requests to run, split evenly across workers.
            request_config: Optional sampling configuration shared by all workers.
            use_tqdm: Whether workers should display a progress bar.

        Returns:
            A flat list of outputs in worker order (placeholder results dropped).
        """
        chunked_infer_requests = chunk_list(infer_requests, self.num_connections)

        # Send the prompts to each worker
        for i, (connection, requests) in enumerate(zip(self.connections, chunked_infer_requests)):
            # When the number of prompts is less than data_parallel_size, some workers will receive empty prompts.
            # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
            # with vLLM's requirement, and we later ignore the result.
            if not requests:
                requests = [RolloutInferRequest(messages=[{'role': 'user', 'content': '<placeholder>'}])]
            # Use a different seed between vLLM engines so sampling diverges across DP ranks.
            # request_config may be None (its default), so guard before touching .seed.
            if request_config is not None and request_config.seed:
                request_config.seed += i * len(requests)
            kwargs = {'infer_requests': requests, 'request_config': request_config, 'use_tqdm': use_tqdm}
            method = 'infer' if not self.use_async_engine else 'async_infer'
            connection.send({'type': 'call', 'method': method, 'kwargs': kwargs})

        all_outputs = [connection.recv() for connection in self.connections]
        # Handle empty prompts (see above)
        all_outputs = [output for output, requests in zip(all_outputs, chunked_infer_requests) if requests]
        all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list

        return all_outputs

    def run(self):
        """Start the uvicorn server (blocking)."""
        args = self.args
        uvicorn.run(self.app, host=args.host, port=args.port, log_level=args.log_level)


def rollout_main(args: Optional[Union[List[str], RolloutArguments]] = None) -> None:
    """Build the rollout deployment pipeline from `args` and run it."""
    pipeline = SwiftRolloutDeploy(args)
    pipeline.main()


def is_accessible(port: int):
    """Return True when a rollout server answers a model-list probe on `port`."""
    probe_client = InferClient(port=port)
    try:
        probe_client.get_model_list()
    except ClientConnectorError:
        return False
    else:
        return True


@contextmanager
def run_rollout(args: RolloutArguments, return_url: bool = False):
    """
    Context manager that launches a rollout server in a spawned subprocess.

    Polls the server until it answers, then yields either the base URL (when
    `return_url` is True) or the port. The subprocess is terminated and joined
    when the context exits.

    Args:
        args: RolloutArguments, or any dataclass whose overlapping fields can
            be projected onto RolloutArguments.
        return_url: Yield 'http://127.0.0.1:<port>/v1' instead of the port.
    """
    if isinstance(args, RolloutArguments) and args.__class__.__name__ == 'RolloutArguments':
        deploy_args = args
    else:
        # Project a different argument dataclass onto RolloutArguments, keeping
        # only fields RolloutArguments accepts and dropping unset (None) values.
        args_dict = asdict(args)
        parameters = inspect.signature(RolloutArguments).parameters
        for k in list(args_dict.keys()):
            if k not in parameters or args_dict[k] is None:
                args_dict.pop(k)
        deploy_args = RolloutArguments(**args_dict)

    mp = multiprocessing.get_context('spawn')
    process = mp.Process(target=rollout_main, args=(deploy_args, ))
    process.start()
    try:
        # Block until the server responds before handing control to the caller.
        while not is_accessible(deploy_args.port):
            time.sleep(1)
        yield f'http://127.0.0.1:{deploy_args.port}/v1' if return_url else deploy_args.port
    finally:
        process.terminate()
        process.join()  # reap the child so it does not linger as a zombie
        logger.info('The deployment process has been terminated.')
